[KSM] support keep sampling mask #7460

zeroRains wants to merge 8 commits into PaddlePaddle:release/2.6 from
Conversation

Thanks for your contribution!

Pull request overview
This PR adds keep-sampling-mask output to the inference service: after top_p/top_k truncated sampling, the vocabulary indices retained at each step are returned (including via streaming) in sparse form, enabling client-side interpretability and debugging analysis, together with the corresponding CLI switch and end-to-end tests.
Changes:
- Add the startup flag --enable-keep-sampling-mask and thread the switch through Engine/Worker/Sampler/TokenProcessor/OpenAI Serving.
- Compute the sparse sampling_mask (and logZ) during sampling; on the non-FD_USE_GET_SAVE_OUTPUT_V1 path, send it to token_processor via a ZMQ side-channel and surface it in the OpenAI response.
- Add/update unit tests and e2e tests covering the format and consistency of sampling_mask in streaming and non-streaming responses.
Reviewed changes
Copilot reviewed 19 out of 19 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| tests/output/test_process_batch_output.py | Initialize the use_sampling_mask field on the processor built for tests. |
| tests/entrypoints/openai/test_max_streaming_tokens.py | Update the call to match the new sampling_mask_list parameter of the chat choice. |
| tests/e2e/test_ernie_21b_mtp.py | e2e: enable keep sampling mask in the launch args and add streaming/non-streaming/different-top_p verification cases. |
| fastdeploy/worker/worker_process.py | Add --enable-keep-sampling-mask to the worker CLI (with underscore and hyphen aliases). |
| fastdeploy/worker/output.py | Add sampling_mask and logz_per_batch fields to SamplerOutput (sparse mask and logZ). |
| fastdeploy/worker/gpu_model_runner.py | Read the config switch; create the sampling_mask ZMQ client on the non-V1 path; pass keep_sampling_mask in prepare_inputs; pass sampling_mask_zmq_client through save_output. |
| fastdeploy/output/token_processor.py | Add a sampling_mask ZMQ server on the non-V1 path; receive the mask each step and write it into RequestOutput.outputs. |
| fastdeploy/output/stream_transfer_data.py | Add a sampling_mask field to StreamTransferData to carry the sparse mask. |
| fastdeploy/model_executor/pre_and_post_process.py | Add sampling_mask to stream transfer data; add side-channel sending in save_output_*; add a logZ-based logprobs renormalization step. |
| fastdeploy/model_executor/layers/sample/sampler.py | Add _compute_sampling_mask; the normal and speculative paths compute sampling_mask/logZ before sampling and write them into SamplerOutput. |
| fastdeploy/model_executor/layers/sample/meta_data.py | Add a keep_sampling_mask field to SamplingMetadata. |
| fastdeploy/model_executor/layers/sample/logprobs.py | build_output_logprobs now also returns output_logits; add logprobs_renormalize_with_logz. |
| fastdeploy/entrypoints/openai/serving_chat.py | Output sampling_mask in stream/full responses; add _make_sampling_mask_list and flatten it when aggregating choices. |
| fastdeploy/entrypoints/openai/protocol.py | Add a sampling_mask field (List[List[int]]) to the OpenAI protocol response models. |
| fastdeploy/engine/request.py | Add a sampling_mask field to CompletionOutput and include it in the to_dict output. |
| fastdeploy/engine/engine.py | Add enable_keep_sampling_mask to worker_store_true_flag and pass the switch through when launching workers. |
| fastdeploy/engine/common_engine.py | Same as engine.py: pass enable_keep_sampling_mask through to the worker launch args. |
| fastdeploy/engine/args_utils.py | Add the --enable-keep-sampling-mask argument and its help text to EngineArgs/CLI. |
| fastdeploy/config.py | Add an enable_keep_sampling_mask default field to ModelConfig. |
```python
# Group by request using accept_num so each entry is List[np.ndarray] (n arrays per req).
real_bsz = model_output.accept_num.shape[0]
accept_nums = model_output.accept_num[:real_bsz].flatten().tolist()
mask_dict = {}
offset = 0
total_masks = len(sampler_output.sampling_mask)
for i, n in enumerate(accept_nums):
    n = max(int(n), 0)
    if n > 0:
        # List of n sparse index arrays, one per accepted token
        mask_dict[i] = [arr.tolist() for arr in sampler_output.sampling_mask[offset : offset + n]]
    offset += n
if offset != total_masks:
    raise ValueError(f"sampling_mask length mismatch: expected {offset}, got {total_masks}")
```
When the speculative path sends sampling_mask, it groups by model_output.accept_num and builds the mask_dict keys from i, but the output of speculate_save_output(_topk) above is restored to the original batch order via index_to_batch_id + enable_pd_reorder. If PD reorder is enabled, sampler_output.sampling_mask / accept_num / logz_per_batch are not consistently reordered here, so the mask_dict keys/grouping will no longer match the batch_id seen on the token_processor side. Suggestion: before building mask_dict, recover/reorder accept_num and sampling_mask to match the output (reuse recover_share_inputs["accept_num_cpu"] or extend recover_batch_index_for_sampler_output), and reorder logz_per_batch in the same way.
Suggested change:

```python
# Recover it to the same batch order as speculate_save_output(_topk) before grouping by request.
real_bsz = recover_share_inputs["accept_num_cpu"].shape[0]
raw_accept_nums = model_output.accept_num[:real_bsz].flatten().tolist()
recovered_accept_nums = recover_share_inputs["accept_num_cpu"][:real_bsz].flatten().tolist()
total_masks = len(sampler_output.sampling_mask)
sampling_mask_groups = []
offset = 0
for n in raw_accept_nums:
    n = max(int(n), 0)
    sampling_mask_groups.append(sampler_output.sampling_mask[offset : offset + n])
    offset += n
if offset != total_masks:
    raise ValueError(f"sampling_mask length mismatch: expected {offset}, got {total_masks}")
recovered_sampling_mask_groups = [[] for _ in range(real_bsz)]
if model_output.index_to_batch_id is None:
    batch_id_map = list(range(real_bsz))
else:
    batch_id_map = np.asarray(model_output.index_to_batch_id[:real_bsz]).flatten().tolist()
for i, group in enumerate(sampling_mask_groups):
    batch_id = int(batch_id_map[i])
    if batch_id < 0 or batch_id >= real_bsz:
        raise ValueError(f"sampling_mask batch_id out of range: {batch_id}, real_bsz={real_bsz}")
    recovered_sampling_mask_groups[batch_id] = group
mask_dict = {}
for i, n in enumerate(recovered_accept_nums):
    n = max(int(n), 0)
    if len(recovered_sampling_mask_groups[i]) != n:
        raise ValueError(
            f"sampling_mask group size mismatch for batch {i}: "
            f"expected {n}, got {len(recovered_sampling_mask_groups[i])}"
        )
    if n > 0:
        # List of n sparse index arrays, one per accepted token.
        mask_dict[i] = [arr.tolist() for arr in recovered_sampling_mask_groups[i]]
```
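The restore-then-group logic in the suggestion can be illustrated with a standalone plain-Python sketch. The function name and signature here are illustrative, not FastDeploy APIs: a flat list of per-token sparse masks is first split by per-slot accept counts, then each chunk is reassigned to its original batch id.

```python
# Hypothetical sketch of "restore batch order via index_to_batch_id, then
# group the flat per-token masks by accept_num" (names are illustrative).
from typing import Dict, List, Optional


def group_masks_by_request(
    flat_masks: List[List[int]],                     # one sparse index list per accepted token
    accept_nums: List[int],                          # accepted tokens per sampler slot
    index_to_batch_id: Optional[List[int]] = None,   # PD-reorder mapping, if any
) -> Dict[int, List[List[int]]]:
    real_bsz = len(accept_nums)
    batch_ids = index_to_batch_id if index_to_batch_id is not None else list(range(real_bsz))
    grouped: Dict[int, List[List[int]]] = {}
    offset = 0
    for slot, n in enumerate(accept_nums):
        n = max(int(n), 0)
        chunk = flat_masks[offset : offset + n]
        offset += n
        if n > 0:
            grouped[batch_ids[slot]] = chunk
    if offset != len(flat_masks):
        raise ValueError(f"sampling_mask length mismatch: expected {offset}, got {len(flat_masks)}")
    return grouped


# Two requests swapped by PD reorder: sampler slot 0 belongs to batch id 1.
masks = [[5, 9], [2], [7, 8, 11]]
out = group_masks_by_request(masks, accept_nums=[1, 2], index_to_batch_id=[1, 0])
print(out)  # {1: [[5, 9]], 0: [[2], [7, 8, 11]]}
```

Without the `index_to_batch_id` pass, slot order and batch order coincide only when PD reorder is off, which is exactly the mismatch the review comment points out.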
```python
logz = paddle.to_tensor(logz, dtype=logprobs.dtype)
# Renormalize: log π_masked = log π_full - log Z_K
# Only normalize valid candidates; padding positions use -inf
valid_mask = paddle.isfinite(logprobs)
normalized_logprobs = paddle.where(
    valid_mask, logprobs - logz.unsqueeze(1), paddle.full_like(logprobs, float("-inf"))
)
# Update logprobs_tensors with normalized values
return LogprobsTensors(
    logprob_token_ids=logprobs_tensors.logprob_token_ids,
    logprobs=normalized_logprobs,
    selected_token_ranks=logprobs_tensors.selected_token_ranks,
)
```
logprobs_renormalize_with_logz currently subtracts logZ_K uniformly at every isfinite position, but the top-k entries in logprobs_tensors are taken via topk on the full distribution and do not necessarily all fall inside the candidate set K left by top_p/top_k truncation (especially when top_p is small and max_logprobs is large). As a result, the returned "renormalized logprobs" can still contain finite values for tokens outside the candidate set, which violates the semantics of the truncated distribution. Suggestion: use sampling_mask (or the candidate set) to set the logprobs of tokens outside K to -inf / None and renormalize only the entries inside K, or construct the logprobs output directly from the truncated distribution.
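The fix this comment suggests can be sketched in numpy. This is an illustrative standalone example, not the FastDeploy implementation: top-k entries outside the candidate set K are masked to -inf first, and logZ_K is then computed only over the surviving entries.

```python
# Illustrative numpy sketch of "mask outside-K entries before renormalizing",
# so every finite value belongs to the truncated distribution.
import numpy as np


def renormalize_within_candidates(
    logprobs: np.ndarray,        # [num_topk] full-distribution log-probs
    token_ids: np.ndarray,       # [num_topk] token ids matching logprobs
    candidate_ids: np.ndarray,   # sparse sampling_mask: ids retained by top_p/top_k
) -> np.ndarray:
    in_k = np.isin(token_ids, candidate_ids)
    masked = np.where(in_k, logprobs, -np.inf)
    # logZ_K computed over the candidate entries only
    logz = np.logaddexp.reduce(masked[in_k])
    return np.where(in_k, masked - logz, -np.inf)


lp = np.log(np.array([0.5, 0.3, 0.2]))
ids = np.array([10, 20, 30])
out = renormalize_within_candidates(lp, ids, candidate_ids=np.array([10, 20]))
# token 30 is outside K, so it becomes -inf; the rest renormalize to sum to 1.
assert np.isneginf(out[2])
assert np.isclose(np.exp(out[:2]).sum(), 1.0)
```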
Codecov Report

Additional details and impacted files:

```
@@            Coverage Diff             @@
##           release/2.6    #7460   +/-  ##
==============================================
  Coverage             ?   73.87%
==============================================
  Files                ?      376
  Lines                ?    53130
  Branches             ?     8300
==============================================
  Hits                 ?    39250
  Misses               ?    11129
  Partials             ?     2751
```

Flags with carried forward coverage won't be shown. View full report in Codecov by Sentry.
Pull request overview
Copilot reviewed 24 out of 24 changed files in this pull request and generated 5 comments.
Comments suppressed due to low confidence (1)
tests/output/test_process_batch_draft_tokens.py:39

- Here cfg.model_config is a MagicMock; without an explicit enable_keep_sampling_mask=False, TokenProcessor may treat keep_sampling_mask as enabled and try to create/bind a ZMQ IPC server (its path contains the hard-coded "9700"), which is prone to conflicts when tests run concurrently or repeatedly. Suggest setting enable_keep_sampling_mask = False on cfg.model_config (unless this case really is meant to cover the feature and handles socket cleanup/isolation properly).

```python
# Mock cfg
cfg = MagicMock()
cfg.speculative_config = MagicMock()
cfg.parallel_config.local_data_parallel_id = 0
cfg.parallel_config.engine_worker_queue_port = ["9700"]
cfg.speculative_config.method = "mtp"
cfg.speculative_config.num_speculative_tokens = 3
cfg.model_config = MagicMock()
cfg.model_config.enable_logprob = True
```
```python
# where the value is a list[int] or list[list[int]] of allowed token ids
sampling_masks_per_request = {}
if self.use_sampling_mask and not envs.FD_USE_GET_SAVE_OUTPUT_V1 and hasattr(self, "sampling_mask_zmq_server"):
    _, mask_data = self.sampling_mask_zmq_server.receive_pyobj_once(block=True)
    if mask_data is not None and isinstance(mask_data, dict):
        sampling_masks_per_request = mask_data
```
This synchronously waits for the sampling_mask side-channel message with block=True and has no timeout/fallback path: if the worker never sends (e.g. the client was not created, a send failed, or some runner is not wired into this side-channel), TokenProcessor blocks forever and overall inference hangs. Suggest switching to non-blocking polling (block=False) and letting the step continue when the message is missing, or adding a configurable timeout plus an error log, to avoid deadlock.
Suggested change:

```python
# where the value is a list[int] or list[list[int]] of allowed token ids.
# Use a non-blocking receive so a missing side-channel message does not
# stall the whole token processing loop.
sampling_masks_per_request = {}
if self.use_sampling_mask and not envs.FD_USE_GET_SAVE_OUTPUT_V1 and hasattr(self, "sampling_mask_zmq_server"):
    mask_data = None
    try:
        _, mask_data = self.sampling_mask_zmq_server.receive_pyobj_once(block=False)
    except zmq.Again:
        mask_data = None
    except Exception:
        llm_logger.exception(
            "Failed to receive sampling_mask side-channel message; "
            "continuing without sampling mask for this step."
        )
        mask_data = None
    if mask_data is not None:
        if isinstance(mask_data, dict):
            sampling_masks_per_request = mask_data
        else:
            llm_logger.warning(
                "Ignore invalid sampling_mask side-channel payload type: %s",
                type(mask_data).__name__,
            )
```
```python
def setup_method(self):
    self.mock_cfg = MagicMock()
    self.mock_cfg.parallel_config.local_data_parallel_id = 0
    self.mock_cfg.parallel_config.engine_worker_queue_port = ["9700"]
```
When cfg/model_config is a MagicMock here, getattr(cfg.model_config, "enable_keep_sampling_mask", False) in TokenProcessor.__init__ returns a truthy MagicMock, so the unit test unexpectedly enables keep_sampling_mask and tries to bind a fixed IPC address (/dev/shm/sampling_mask_output_rank_0_9700.socket), which easily hits "Address already in use" or leaks resources under parallel/repeated runs. Suggest explicitly setting enable_keep_sampling_mask=False on mock_cfg.model_config (or patching envs.FD_USE_GET_SAVE_OUTPUT_V1=True to avoid creating the server).
Suggested change:

```python
self.mock_cfg.parallel_config.engine_worker_queue_port = ["9700"]
self.mock_cfg.model_config.enable_keep_sampling_mask = False
```
| """为 TokenProcessor 测试设置通用的 mock 对象。""" | ||
| self.mock_cfg = MagicMock() | ||
| self.mock_cfg.parallel_config.local_data_parallel_id = 0 | ||
| self.mock_cfg.parallel_config.engine_worker_queue_port = ["9700"] |
This setUp's MagicMock-built cfg carries the same risk of enable_keep_sampling_mask being misread as True, so TokenProcessor may unexpectedly create and bind the sampling_mask ZMQ IPC socket during unit tests, causing port/file conflicts and test flakiness. Suggest explicitly setting self.mock_cfg.model_config.enable_keep_sampling_mask = False.
Suggested change:

```python
self.mock_cfg.parallel_config.engine_worker_queue_port = ["9700"]
self.mock_cfg.model_config.enable_keep_sampling_mask = False
```
```python
def setUp(self):
    self.cfg = MagicMock()
    self.cfg.model_config.enable_logprob = True
    self.cfg.speculative_config.method = None
    self.cfg.parallel_config.local_data_parallel_id = 0
    self.cfg.parallel_config.engine_worker_queue_port = ["9700"]
    self.cached_generated_tokens = MagicMock()
```
This test's cfg is built from a MagicMock, so TokenProcessor initialization may read enable_keep_sampling_mask as a truthy MagicMock and unexpectedly create and bind the sampling_mask ZMQ IPC server (fixed name/port), leading to conflicts or resource leaks between test cases. Suggest explicitly setting enable_keep_sampling_mask=False on cfg.model_config.
```python
# Renormalize logprobs to match truncated sampling distribution (when enabled).
if sampler_output.logprobs_tensors is not None and sampler_output.logz_per_batch is not None:
    sampler_output.logprobs_tensors = logprobs_renormalize_with_logz(
        sampler_output.logprobs_tensors.logprobs,
        sampler_output.logz_per_batch,
        sampler_output.logprobs_tensors,
    )
```
When renormalizing logprobs here, take care not to double-normalize against the top_p_normalized_logprobs logic in Sampler.compute_logprobs; otherwise, when the request has already enabled top_p_normalized_logprobs (top_p != 1.0), logZ is subtracted twice and the returned logprobs values are wrong. Suggest deciding per request/token whether top_p normalization has already been applied before applying logz_per_batch (or applying it only to rows that have not yet been normalized).
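The double-subtraction hazard the comment describes can be checked numerically. A small numpy demonstration with illustrative values: subtracting logZ_K once yields a proper distribution over the candidate set K, while subtracting it twice does not.

```python
# Numerical check of the double-normalization hazard (illustrative values).
import numpy as np

logprobs_k = np.log(np.array([0.5, 0.3]))   # full-distribution log-probs of the candidate set K
logz = np.logaddexp.reduce(logprobs_k)      # log Z_K = log(0.8)

once = logprobs_k - logz                    # correct: normalized over K
twice = once - logz                         # bug: two code paths each subtract log Z_K

assert np.isclose(np.exp(once).sum(), 1.0)    # proper distribution over K
assert np.isclose(np.exp(twice).sum(), 1.25)  # over-normalized by a factor of 1/0.8
```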
Pull request overview
Copilot reviewed 24 out of 24 changed files in this pull request and generated 6 comments.
Comments suppressed due to low confidence (1)
tests/output/test_process_batch_draft_tokens.py:39

- In this case cfg/model_config use MagicMock; without explicitly setting enable_keep_sampling_mask, TokenProcessor will try to create and bind the sampling_mask ZMQ IPC server (with a fixed socket file name) during setUp, introducing conflict/resource-leak risk between tests. Suggest explicitly setting cfg.model_config.enable_keep_sampling_mask to False here (or closing the server in teardown).

```python
# Mock cfg
cfg = MagicMock()
cfg.speculative_config = MagicMock()
cfg.parallel_config.local_data_parallel_id = 0
cfg.parallel_config.engine_worker_queue_port = ["9700"]
cfg.speculative_config.method = "mtp"
cfg.speculative_config.num_speculative_tokens = 3
cfg.model_config = MagicMock()
cfg.model_config.enable_logprob = True
```
```python
k_per_row = topp_mask.astype("int32").sum(axis=-1, keepdim=True)  # [B, 1]
# boundary_idx = last True position (k-1), clamp for safety
boundary_idx = (k_per_row - 1).clip(min=0)  # [B, 1]
boundary_prob = paddle.take_along_axis(
    renorm_sorted_probs,
    boundary_idx,
    axis=-1,
)  # [B, 1]
topp_mask = topp_mask | (renorm_sorted_probs >= boundary_prob)
```
In _compute_sampling_mask(), boundary_idx is obtained from a bool sum and is int32; passing it directly to paddle.take_along_axis may trigger an index dtype incompatibility (Paddle generally requires int64 indices), causing a runtime error when keep_sampling_mask is enabled. Suggest explicitly casting boundary_idx to int64 before take_along_axis.
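The boundary-index logic above can be sketched end to end in numpy. This is an assumption-laden illustration, not the Paddle kernel: it assumes probs are already softmaxed and that the mask keeps tokens while the exclusive cumulative mass stays below top_p, and it includes the explicit int64 cast the comment recommends.

```python
# Numpy sketch of computing the sparse top_p sampling mask, with the explicit
# int64 cast on the boundary index (illustrative; the Paddle code may differ).
import numpy as np


def sparse_top_p_mask(probs: np.ndarray, top_p: float) -> list:
    order = np.argsort(-probs, axis=-1)                      # [B, V], descending
    sorted_probs = np.take_along_axis(probs, order, axis=-1)
    cumsum = np.cumsum(sorted_probs, axis=-1)
    topp_mask = cumsum - sorted_probs < top_p                # keep until mass reaches top_p
    k_per_row = topp_mask.astype(np.int32).sum(axis=-1, keepdims=True)
    boundary_idx = np.clip(k_per_row - 1, 0, None).astype(np.int64)  # explicit int64 cast
    boundary_prob = np.take_along_axis(sorted_probs, boundary_idx, axis=-1)
    topp_mask = topp_mask | (sorted_probs >= boundary_prob)  # keep ties at the boundary
    # Sparse format: retained vocab indices per row.
    return [order[b][topp_mask[b]].tolist() for b in range(probs.shape[0])]


probs = np.array([[0.6, 0.3, 0.1]])
print(sparse_top_p_mask(probs, top_p=0.8))  # [[0, 1]]
```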
```python
# Send sampling_mask via ZMQ side-channel when enabled.
if sampler_output.sampling_mask is not None and model_output.mp_rank == 0:
    # sampling_mask is List[np.ndarray] of sparse int indices, one array per request.
    mask_dict = {i: arr.tolist() for i, arr in enumerate(sampler_output.sampling_mask)}

    sampling_mask_zmq_client.send_pyobj(mask_dict)
```
With enable_pd_reorder=True, save_output_normal() calls recover_batch_index_for_sampler_output(), but that function currently does not reorder sampler_output.sampling_mask (or logz_per_batch). The sampling_mask then no longer lines up with the recovered sampled_token_ids / batch_id, and the sampling_mask returned to clients may be attached to the wrong request. Suggest reordering sampling_mask/logz_per_batch by index_to_batch_id inside the recover flow as well, or reordering the list by index_to_batch_id before sending mask_dict.
```python
# Group by request using accept_num so each entry is List[np.ndarray] (n arrays per req).
real_bsz = model_output.accept_num.shape[0]
accept_nums = model_output.accept_num[:real_bsz].flatten().tolist()
mask_dict = {}
offset = 0
total_masks = len(sampler_output.sampling_mask)
for i, n in enumerate(accept_nums):
    n = max(int(n), 0)
    if n > 0:
        # List of n sparse index arrays, one per accepted token
        mask_dict[i] = [arr.tolist() for arr in sampler_output.sampling_mask[offset : offset + n]]
    offset += n
if offset != total_masks:
    raise ValueError(f"sampling_mask length mismatch: expected {offset}, got {total_masks}")
```
The grouping/sending of sampling_mask in save_output_specualate() likewise ignores the index_to_batch_id reordering under enable_pd_reorder (recover_batch_index_for_sampler_output does not handle sampling_mask either). With PD reorder enabled, the per-request sampling_mask dispatch here is misaligned. Suggest recovering the alignment between sampling_mask and accept_num before building mask_dict (either extend recover_batch_index_for_sampler_output to support sampling_mask/logz_per_batch, or explicitly reorder by index_to_batch_id here and then group by accept_num).
Suggested change:

```python
# It is flattened in sampler-output order, so when PD reorder is enabled we must
# first restore the per-request chunk order with index_to_batch_id before grouping.
real_bsz = model_output.accept_num.shape[0]
accept_nums_for_sampling_mask = model_output.accept_num[:real_bsz].flatten().tolist()
total_masks = len(sampler_output.sampling_mask)
restored_sampling_mask = [[] for _ in range(real_bsz)]
offset = 0
for sampler_idx, n in enumerate(accept_nums_for_sampling_mask):
    n = max(int(n), 0)
    next_offset = offset + n
    mask_chunk = sampler_output.sampling_mask[offset:next_offset]
    if len(mask_chunk) != n:
        raise ValueError(
            f"sampling_mask length mismatch while grouping: expected {n}, got {len(mask_chunk)} "
            f"for sampler_idx {sampler_idx}"
        )
    if model_output.enable_pd_reorder:
        batch_id = int(model_output.index_to_batch_id[sampler_idx])
    else:
        batch_id = sampler_idx
    restored_sampling_mask[batch_id] = mask_chunk
    offset = next_offset
if offset != total_masks:
    raise ValueError(f"sampling_mask length mismatch: expected {offset}, got {total_masks}")
mask_dict = {}
for batch_id, mask_chunk in enumerate(restored_sampling_mask):
    if mask_chunk:
        # List of sparse index arrays, one per accepted token for this request.
        mask_dict[batch_id] = [arr.tolist() for arr in mask_chunk]
```
```python
# Send sampling_mask via ZMQ side-channel when enabled.
if sampler_output.sampling_mask is not None and model_output.mp_rank == 0:
    # sampling_mask is List[np.ndarray] of sparse int indices, one array per request.
    mask_dict = {i: arr.tolist() for i, arr in enumerate(sampler_output.sampling_mask)}

    sampling_mask_zmq_client.send_pyobj(mask_dict)
```
The sampling_mask send here does not check whether sampling_mask_zmq_client is None. The parameter declaration allows None, so an upstream config/injection mismatch (e.g. keep_sampling_mask set to true but the client not initialized) raises an AttributeError and kills the inference thread. Suggest an explicit None check/assert before send_pyobj with a clearer error message.
```python
# logZ_K for each request: log(sum(probs in candidate set K))
# Used for renormalizing logprobs to match the truncated sampling distribution.
# Shape: [num_reqs]
```
The comment on SamplerOutput.logz_per_batch says "Shape: [num_reqs]", but on the speculative decoding path logz_per_batch is actually computed flattened per accepted token (shape closer to [total_accepted_tokens]) and is only used for logprobs renormalization. Suggest updating the comment/name to reflect the real dimensions of the two paths and avoid later misuse.
Suggested change:

```python
# logZ_K used for logprob renormalization:
# - Non-speculative decoding: per-request values with shape [num_reqs].
# - Speculative decoding: flattened per-accepted-token values with shape
#   approximately [total_accepted_tokens].
# Callers MUST NOT assume this is always shaped by num_reqs; interpret the
# dimension according to the current decoding path.
```
```python
# 1-D int32 numpy array of vocab indices retained by top_p/top_k for
# this request. Sparse format: only retained positions, not a dense
# vocab-sized bool mask.
sampling_mask: Optional[np.array] = None
```
The type annotation of StreamTransferData.sampling_mask is written as Optional[np.array], but np.array is a function, not a type; it should be np.ndarray (or something more specific such as np.ndarray[np.int32]). As written it misleads static checkers/IDEs.
PaddlePaddle-bot left a comment

🤖 AI Code Review
2026-04-20 22:08:55
📋 Review Summary

PR overview: this PR adds the Keep Sampling Mask (KSM) feature to FastDeploy, returning the sparse list of vocabulary indices retained during top_p/top_k sampling and supporting logprobs renormalization.
Scope of changes: sampler, logprobs, pre_and_post_process, token_processor, serving_chat, protocol, config
Impact tags: [OP] [APIServer] [Engine]

📝 PR Convention Check

- The [KSM] tag in the PR title is not in the official tag list; use [Feature] instead.
- This PR targets release/2.6 (not develop); per the conventions, the title should be prefixed with [Cherry-Pick] and end with the original PR ID.

Suggested title (ready to copy):
[Cherry-Pick][Feature] Support keep sampling mask for top_p/top_k candidate set (#original-PR-ID)
Issues

| Level | File | Summary |
|---|---|---|
| 🔴 Bug | pre_and_post_process.py:461 | send_pyobj is called on sampling_mask_zmq_client even though it may be None |
| 🔴 Bug | pre_and_post_process.py:710 | Same risk on the speculative path |
| 🟡 Suggestion | pre_and_post_process.py:386 | Implicit logprobs renormalization may change semantics for existing logprobs users |
| 🟡 Suggestion | sampler.py:114 | _compute_sampling_mask runs a full argsort every step, which is expensive |

Overall Assessment

The implementation is complete, covering both the MTP and non-MTP paths, and the tests cover streaming/non-streaming and top_p comparison scenarios. The main issues are the missing None guard on the ZMQ client, which can crash at runtime, and the fact that logprobs renormalization is an implicit behavior change for existing --enable-logprob users; explicit control is recommended.
```python
# sampling_mask is List[np.ndarray] of sparse int indices, one array per request.
mask_dict = {i: arr.tolist() for i, arr in enumerate(sampler_output.sampling_mask)}

sampling_mask_zmq_client.send_pyobj(mask_dict)
```
🔴 Bug: sampling_mask_zmq_client is typed Optional[ZmqIpcClient] (default None), but send_pyobj is called here without a None check.
The current logic only checks sampler_output.sampling_mask is not None and model_output.mp_rank == 0; if the configuration is inconsistent (e.g. enable_keep_sampling_mask=True but the ZMQ client was not initialized for some reason), this raises AttributeError: 'NoneType' object has no attribute 'send_pyobj'.

Suggested fix:

```python
if sampler_output.sampling_mask is not None and model_output.mp_rank == 0 and sampling_mask_zmq_client is not None:
```

```python
    offset += n
if offset != total_masks:
    raise ValueError(f"sampling_mask length mismatch: expected {offset}, got {total_masks}")
sampling_mask_zmq_client.send_pyobj(mask_dict)
```

🔴 Bug: the same issue as in save_output_normal: sampling_mask_zmq_client may be None (the parameter defaults to None), so add a None guard before the call.

Suggested fix:

```python
if sampler_output.sampling_mask is not None and model_output.mp_rank == 0 and sampling_mask_zmq_client is not None:
```

```python
)

# Renormalize logprobs to match truncated sampling distribution (when enabled).
if sampler_output.logprobs_tensors is not None and sampler_output.logz_per_batch is not None:
```
🟡 Suggestion: when --enable-logprob and --enable-keep-sampling-mask are enabled together, this implicitly renormalizes logprobs over the candidate set. That changes the semantics of the existing logprobs (from a full-vocabulary log_softmax to normalization over the truncated candidate set) and may affect downstream users who rely on the original values.

Suggestions:
- make the renormalization separately opt-in (e.g. a --logprobs-renormalize flag), or
- add an independent field to the API response (e.g. normalized_logprobs) instead of overwriting logprobs, or
- at minimum, state clearly in the docs/startup log that --enable-keep-sampling-mask changes the semantics of the logprobs output.
```python
top_p: paddle.Tensor,
top_k: Optional[paddle.Tensor] = None,
top_k_list: Optional[list] = None,
) -> tuple[List[np.ndarray], np.ndarray]:
```
🟡 Suggestion: _compute_sampling_mask runs argsort over a [B, vocab_size] tensor at every decode step (O(B·V·log V)), a significant cost for large-vocabulary models (e.g. GLM-4's 151k vocab).
top_k_top_p_sampling already performs similar sorting internally. Consider:
- reusing the sort results from top_k_top_p_sampling to build the mask, avoiding a duplicate sort;
- or fusing the mask computation and sampling into a single kernel call.

Also, tuple[List[np.ndarray], np.ndarray] uses the lowercase tuple syntax introduced in Python 3.9; if the project must support 3.8, change it to Tuple[List[np.ndarray], np.ndarray].
Motivation

This PR implements the Keep Sampling Mask (KSM) feature for FastDeploy, which returns the list of vocabulary indices retained during top_p/top_k sampling (in sparse format).

Today, when the inference engine performs top_p/top_k sampling it only returns the final sampled token ID and provides no information about the candidate set used during sampling. This makes client-side interpretability and debugging analysis difficult.

This PR adds a sampling_mask field that records the vocabulary indices retained when each token was sampled (sparse format), and provides candidate-set-based logprobs renormalization.
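A hypothetical client-side sketch of consuming the new fields. The response shape used here (choice["sampling_mask"], choice["logprobs"]) and all values are illustrative, following the PR description rather than a published FastDeploy API reference.

```python
# Hypothetical consumer of the sparse sampling_mask + renormalized logprobs.
import math

choice = {
    "token_ids": [1024, 7],
    # one sparse list of retained vocab ids per generated token
    "sampling_mask": [[1024, 88, 6012], [7, 301]],
    # renormalized log-probs of the sampled tokens (illustrative values)
    "logprobs": [math.log(0.7), math.log(0.9)],
}

for tok, mask, lp in zip(choice["token_ids"], choice["sampling_mask"], choice["logprobs"]):
    # every sampled token must lie inside its step's candidate set
    assert tok in mask
    # a renormalized log-prob is a log of a probability, so it is <= 0
    assert lp <= 0.0
    print(f"token {tok}: candidate set size {len(mask)}, p = {math.exp(lp):.2f}")
```

Checks like these are essentially what the PR's e2e tests assert for streaming and non-streaming responses.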
Modifications

- Add a _compute_sampling_mask method in sampler.py.
- Add the startup argument --enable-keep-sampling-mask.
- Add a logZ renormalization function, logprobs_renormalize_with_logz, in logprobs.py.
- Call the renormalize function in post_process of pre_and_post_process.py.
Usage or Command

Service launch command:

Accuracy Tests

yes

Checklist
- Add at least one tag from the list in the PR title: [FDConfig], [APIServer], [Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]
- Run pre-commit before commit.
- For a release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.